{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yZmFWTnx_PQ-"
      },
      "source": [
        "## Evaluate E$^3$-S/32, with 8 experts, pre-trained on ILSVRC2012 and fine-tuned on CIFAR100"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8lQf6R1BOYlp"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "from jax import numpy as jnp\n",
        "import numpy as np\n",
        "import tensorflow_datasets as tfds\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "from vmoe.nn import models\n",
        "from vmoe.data import input_pipeline\n",
        "from vmoe.checkpoints import partitioned"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V_6jHA1f_ef3"
      },
      "source": [
        "### Construct model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J2MwqMchM1yY"
      },
      "outputs": [],
      "source": [
        "# Evaluation hyper-parameters. Everything a reader might want to tune lives\n",
        "# here; the model configuration below is derived from these constants.\n",
        "BATCH_SIZE = 1024         # Number of images processed in each step.\n",
        "NUM_CLASSES = 100         # Number of CIFAR100 classes.\n",
        "IMAGE_SIZE = 128          # Image size as input to the model.\n",
        "PATCH_SIZE = 32           # Patch size (patches are square).\n",
        "NUM_LAYERS = 8            # Number of encoder blocks in the transformer.\n",
        "NUM_EXPERTS = 8           # Number of experts in each MoE layer.\n",
        "NUM_SELECTED_EXPERTS = 1  # Maximum number of selected experts per token.\n",
        "ENSEMBLE_SIZE = 2         # Number of ensemble members.\n",
        "# NUM_EXPERTS divided (integer division) by the number of ensemble members.\n",
        "NUM_EXPERTS_PER_ENS_MEMBER = NUM_EXPERTS // ENSEMBLE_SIZE\n",
        "# One token per image patch plus one extra token (presumably the class token,\n",
        "# since the model below uses 'classifier': 'token').\n",
        "NUM_TOKENS_PER_IMAGE = (IMAGE_SIZE // PATCH_SIZE)**2 + 1\n",
        "# Device count used to size the token groups. NOTE(review): this is a fixed\n",
        "# value, not read from jax.device_count() -- confirm it matches the runtime.\n",
        "NUM_DEVICES = 8\n",
        "# (Images per device) * (tokens per image).\n",
        "GROUP_SIZE_PER_ENS_MEMBER = (BATCH_SIZE // NUM_DEVICES) * NUM_TOKENS_PER_IMAGE\n",
        "\n",
        "model_config = {\n",
        "    'name': 'VisionTransformerMoeEnsemble',\n",
        "    'num_classes': NUM_CLASSES,\n",
        "    # Derived from PATCH_SIZE so that it cannot drift away from the value\n",
        "    # used to compute NUM_TOKENS_PER_IMAGE above.\n",
        "    'patch_size': (PATCH_SIZE, PATCH_SIZE),\n",
        "    'hidden_size': 512,\n",
        "    'classifier': 'token',\n",
        "    'representation_size': 512,\n",
        "    'head_bias_init': -10.0,\n",
        "    'encoder': {\n",
        "        'num_layers': NUM_LAYERS,\n",
        "        'num_heads': 8,\n",
        "        'mlp_dim': 2048,\n",
        "        'dropout_rate': 0.0,\n",
        "        'attention_dropout_rate': 0.0,\n",
        "        'moe': {\n",
        "            'ensemble_size': ENSEMBLE_SIZE,\n",
        "            'num_experts': NUM_EXPERTS,\n",
        "            'group_size': GROUP_SIZE_PER_ENS_MEMBER * ENSEMBLE_SIZE,\n",
        "            'layers': (5, 7),\n",
        "            'dropout_rate': 0.0,\n",
        "            'router': {\n",
        "                'num_selected_experts': NUM_SELECTED_EXPERTS,\n",
        "                'noise_std': 1.0,  # This is divided by NUM_EXPERTS.\n",
        "                'importance_loss_weight': 0.005,\n",
        "                'load_loss_weight': 0.005,\n",
        "                'dispatcher': {\n",
        "                    'name': 'einsum',\n",
        "                    'bfloat16': True,\n",
        "                    # If we have group_size tokens per group, with a balanced\n",
        "                    # router, the expected number of tokens per expert is:\n",
        "                    # group_size * num_selected_experts / num_experts.\n",
        "                    # To account for deviations from the average, we give some\n",
        "                    # multiplicative slack to this expected number:\n",
        "                    'capacity_factor': 1.5,\n",
        "                    # This is used to hint pjit about how data is distributed\n",
        "                    # at the input/output of each MoE layer.\n",
        "                    # This value means that the tokens are partitioned across\n",
        "                    # all devices in the mesh (i.e. fully data parallelism).\n",
        "                    'partition_spec': (('expert', 'replica'),),\n",
        "                    # We don't use batch priority for training/fine-tuning.\n",
        "                    'batch_priority': False,\n",
        "                },\n",
        "            },\n",
        "        },\n",
        "    },\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vQvtYHxiO596"
      },
      "outputs": [],
      "source": [
        "# Resolve the model class from its name in the config. The config dict is\n",
        "# read, not mutated (no .pop), so this cell can be re-run without raising\n",
        "# KeyError: 'name'.\n",
        "model_cls = getattr(models, model_config['name'])\n",
        "# Every entry except 'name' is forwarded as a constructor keyword argument.\n",
        "model_kwargs = {k: v for k, v in model_config.items() if k != 'name'}\n",
        "model = model_cls(deterministic=True, **model_kwargs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "djlprye0_oOd"
      },
      "source": [
        "### Load weights"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wcDwZbA3coMY"
      },
      "outputs": [],
      "source": [
        "# GCS prefix of the fine-tuned checkpoint to restore.\n",
        "checkpoint_prefix = 'gs://vmoe_checkpoints/eee_s32_last2_ilsvrc2012_ft_cifar100'\n",
        "# Mesh with a single axis named 'd' built from every device returned by JAX.\n",
        "mesh = partitioned.Mesh(np.asarray(jax.devices()), ('d',))\n",
        "# No target tree and no axis resources are passed to the restore call.\n",
        "checkpoint = partitioned.restore_checkpoint(\n",
        "    prefix=checkpoint_prefix,\n",
        "    tree=None,\n",
        "    axis_resources=None,\n",
        "    mesh=mesh,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v2yxvTOt_r-8"
      },
      "source": [
        "### Create dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yt4gIiYRTFog"
      },
      "outputs": [],
      "source": [
        "# Pre-processing spec: keep the image and label, decode, resize the image to\n",
        "# IMAGE_SIZE and rescale its values to [-1, 1].\n",
        "process = f'keep(\"image\", \"label\")|decode|resize({IMAGE_SIZE}, inkey=\"image\")|value_range(-1,1)'\n",
        "\n",
        "# CIFAR100 test split, batched with BATCH_SIZE images per step.\n",
        "dataset = input_pipeline.get_dataset(\n",
        "    variant='test', name='cifar100', split='test',\n",
        "    batch_size=BATCH_SIZE, process=process)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S0LkAZ_o_vYR"
      },
      "source": [
        "### Run evaluation loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bs2eYPv9bjcJ"
      },
      "outputs": [],
      "source": [
        "# Running counts of correctly classified and of real (non-padding) examples.\n",
        "ncorrect = 0\n",
        "ntotal = 0\n",
        "for batch in tqdm(dataset, desc='Evaluating'):\n",
        "  # The final batch has been padded with fake examples so that the batch size is\n",
        "  # the same as all other batches. The mask tells us which examples are fake.\n",
        "  mask = batch['__valid__']\n",
        "\n",
        "  logits, _ = model.apply({'params': checkpoint}, batch['image'])\n",
        "  # logits shape: (BATCH_SIZE * ENSEMBLE_SIZE, NUM_CLASSES).\n",
        "  logits = jnp.reshape(logits, (-1, ENSEMBLE_SIZE, NUM_CLASSES))\n",
        "  # Note: In the paper, we describe the implementation of E^3 with a jnp.tile\n",
        "  # mechanism. In this open-sourced version of the code, because of the pjit\n",
        "  # backend, we use jnp.repeat instead of jnp.tile for efficiency reasons. This\n",
        "  # explains the reshaping above that, in the paper implementation, would have\n",
        "  # been jnp.reshape(logits, (ENSEMBLE_SIZE, -1, NUM_CLASSES)) followed by\n",
        "  # jax.nn.logsumexp(log_p, axis=0).\n",
        "  log_p = jax.nn.log_softmax(logits)\n",
        "  # Log of the ensemble-averaged probabilities:\n",
        "  # log(mean_m p_m) = logsumexp_m(log p_m) - log(ENSEMBLE_SIZE).\n",
        "  log_mean_p = jax.nn.logsumexp(log_p, axis=1) - jnp.log(ENSEMBLE_SIZE)\n",
        "\n",
        "  preds = jnp.argmax(log_mean_p, axis=1)\n",
        "  # Multiplying by the mask drops the padded examples from both counts.\n",
        "  ncorrect += jnp.sum((preds == batch['label']) * mask)\n",
        "  ntotal += jnp.sum(mask)\n",
        "\n",
        "print(f'Test accuracy: {ncorrect / ntotal * 100:.2f}%')  # Should be 81.26%."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "provenance": [
        {
          "file_id": "1etGGfOO1WjJmEDyWqdNz5XgSMsVb5Z_d",
          "timestamp": 1662960558723
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
